{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learn how to build a simple multilayer perceptron on the MNIST dataset\n",
    "\n",
    "MNIST example from: https://github.com/FluxML/model-zoo/blob/master/mnist/mlp.jl\n",
    "\n",
    "Let's start by loading `Flux`, importing a few things from `Flux` explicitly, and bringing the `repeated` function into our scope."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "using Flux, Flux.Data.MNIST\n",
    "using Flux: onehotbatch, argmax, crossentropy, throttle\n",
    "using Base.Iterators: repeated"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now store all the MNIST images in `imgs` and take a peek into this vector to see what the data looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "imgs = MNIST.images()\n",
    "imgs[3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at the type of an individual image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "typeof(imgs[3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Reorganizing our array of images\n",
    "\n",
    "We see this is a 2D array that stores `ColorTypes`. To work more easily with this data, let's convert all `ColorTypes` to floating-point numbers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "fpt_imgs = float.(imgs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can see what `imgs[3]` looks like as an array of floats, rather than as an array of colors!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "fpt_imgs[3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Let's stack the images to create one large 2D array, `X`, that stores the data for each image as a column.**\n",
    "\n",
    "To do this, we can **first** use `reshape` to unravel each image, creating a 1D array (`Vector`) of floats from a 2D array (`Matrix`) of floats."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "unraveled_fpt_imgs = reshape.(fpt_imgs, :);\n",
    "typeof(unraveled_fpt_imgs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(Note that `Vector` is an alias for a 1D `Array`.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "Vector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This makes `unraveled_fpt_imgs` a `Vector` of `Vector`s, where the third image, `imgs[3]`, is now stored as"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "unraveled_fpt_imgs[3]"
   ]
  },
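  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the unraveling step concrete, here is a minimal sketch on a tiny, made-up matrix `A` (a stand-in for an image, not part of MNIST). It assumes only base Julia and shows that `reshape(A, :)` reads the entries column by column.\n",
    "\n",
    "```julia\n",
    "# A hypothetical 2x2 matrix standing in for a 28x28 image\n",
    "A = [1 2; 3 4]\n",
    "\n",
    "# Unravel it into a Vector; Julia arrays are stored column-major,\n",
    "# so the first column (1, 3) comes before the second column (2, 4)\n",
    "v = reshape(A, :)\n",
    "@assert v == [1, 3, 2, 4]\n",
    "\n",
    "# Reshaping back with the original dimensions recovers the matrix\n",
    "@assert reshape(v, 2, 2) == A\n",
    "```\n",
    "\n",
    "Because the order is fixed, an unraveled image can always be reshaped back into its original 28x28 form."
   ]
  },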
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After using `reshape` to get a `Vector` of `Vector`s, we can use `hcat` to build a `Matrix`, `X`, from `unraveled_fpt_imgs` where the `Vector`s stored in `unraveled_fpt_imgs` will become the columns of `X`.\n",
    "\n",
    "Note that we're using the \"splat\" operator below, `...`, which passes each element of a collection to a function as a separate argument, rather than passing the collection itself as a single argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X = hcat(unraveled_fpt_imgs...)"
   ]
  },
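  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of splatting into `hcat` is easier to see on a small case. The sketch below uses a made-up `Vector` of three short `Vector`s, `vs`, in place of the unraveled images; the names `vs` and `M` are only for illustration.\n",
    "\n",
    "```julia\n",
    "# Three hypothetical \"unraveled images\", each of length 2\n",
    "vs = [[1, 2], [3, 4], [5, 6]]\n",
    "\n",
    "# Splatting passes vs[1], vs[2], vs[3] as separate arguments\n",
    "M = hcat(vs...)\n",
    "@assert M == hcat([1, 2], [3, 4], [5, 6])\n",
    "\n",
    "# Each inner Vector has become one column of the Matrix\n",
    "@assert size(M) == (2, 3)\n",
    "@assert M[:, 2] == [3, 4]\n",
    "```\n",
    "\n",
    "Without the splat, `hcat(vs)` would receive a single argument (the outer `Vector`) instead of the three columns."
   ]
  },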
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### How to go back to images from this 2D `Array`\n",
    "\n",
    "Now each column of `X` is an image reshaped into a vector of floating-point numbers. Let's pick one column and see which digit it holds.\n",
    "\n",
    "Let's try to view the second image in the original array, `imgs`, by taking the second column of `X`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "onefigure = X[:,2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll `reshape` this array to a 2D, 28x28 array,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "t1 = reshape(onefigure,28,28)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and finally use `colorview` from the `Images` package to view the handwritten digit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "using Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "colorview(Gray, t1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Our data is in working order!*\n",
    "\n",
    "For our machine to learn the digit with which each image is associated, we'll need to train it using correct answers. Therefore we'll make use of the `labels` associated with these images from MNIST."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "labels = MNIST.labels() # the true labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One-hot-encode the labels with `onehotbatch`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "Y = onehotbatch(labels, 0:9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "which gives a binary indicator vector for each image: column `i` of `Y` has a single `true` entry, in the row for the digit `labels[i]`."
   ]
  },
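  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what one-hot encoding means, the sketch below builds a single indicator vector by hand. The helper `onehot_manual` is hypothetical and written only for this example; it is not how Flux implements `onehotbatch`.\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper: true where `label` matches a class, false elsewhere\n",
    "onehot_manual(label, classes) = [label == c for c in classes]\n",
    "\n",
    "v = onehot_manual(3, 0:9)\n",
    "\n",
    "# Exactly one entry is set\n",
    "@assert length(v) == 10\n",
    "@assert sum(v) == 1\n",
    "\n",
    "# The digit 3 sits at index 4 because the classes start at 0\n",
    "@assert v[4]\n",
    "```\n",
    "\n",
    "Applying the same idea to every label and stacking the results as columns gives a matrix with the same layout as `Y`."
   ]
  },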
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
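    "Build the network: a hidden layer of 32 `relu` units, followed by a 10-unit output layer whose outputs are normalized with `softmax`."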
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "m = Chain(\n",
    "  Dense(28^2, 32, relu),\n",
    "  Dense(32, 10),\n",
    "  softmax)"
   ]
  },
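  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for what this chain computes, here is a hand-written forward pass in base Julia. It is a sketch under assumptions: the weights are random rather than Flux's initialization, the input is a fake image, and `relu_manual` and `softmax_manual` are hypothetical stand-ins for Flux's `relu` and `softmax`. It relies on a `Dense` layer computing `σ.(W * x .+ b)`.\n",
    "\n",
    "```julia\n",
    "# Hypothetical stand-ins for Flux's activation functions\n",
    "relu_manual(z) = max.(z, 0)\n",
    "# Subtracting the maximum avoids overflow and does not change the result\n",
    "softmax_manual(z) = exp.(z .- maximum(z)) ./ sum(exp.(z .- maximum(z)))\n",
    "\n",
    "# Random weights and zero biases with the same shapes as the two layers\n",
    "W1, b1 = randn(32, 28^2), zeros(32)\n",
    "W2, b2 = randn(10, 32), zeros(10)\n",
    "\n",
    "x = rand(28^2)                      # a fake unraveled image\n",
    "h = relu_manual(W1 * x .+ b1)       # hidden layer: 32 activations\n",
    "p = softmax_manual(W2 * h .+ b2)    # output: one probability per digit\n",
    "\n",
    "@assert length(h) == 32\n",
    "@assert length(p) == 10\n",
    "@assert all(p .>= 0)\n",
    "@assert abs(sum(p) - 1) < 1e-8\n",
    "```\n",
    "\n",
    "The final `softmax` is what lets us read the 10 outputs as probabilities, one for each digit from 0 to 9."
   ]
  },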
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the loss function and accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
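    "# Cross-entropy between the model's predicted probabilities and the one-hot labels\n",
    "loss(x, y) = Flux.crossentropy(m(x), y)\n",
    "# Fraction of images whose predicted class matches the true class\n",
    "accuracy(x, y) = mean(argmax(m(x)) .== argmax(y))"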
   ]
  },
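  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about the loss, the sketch below computes cross-entropy by hand on a made-up three-class prediction. The helper `crossentropy_manual` is hypothetical and assumes the usual definition, averaged over the columns of `y`; Flux's own `crossentropy` may differ in details.\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper: average cross-entropy over the columns of y\n",
    "crossentropy_manual(yhat, y) = -sum(y .* log.(yhat)) / size(y, 2)\n",
    "\n",
    "yhat = [0.1, 0.7, 0.2]   # predicted probabilities for three classes\n",
    "y    = [0, 1, 0]         # one-hot target: the true class is the second\n",
    "\n",
    "# Only the true class contributes, so the loss is -log(0.7)\n",
    "@assert abs(crossentropy_manual(yhat, y) + log(0.7)) < 1e-12\n",
    "\n",
    "# A more confident correct prediction gives a smaller loss\n",
    "@assert crossentropy_manual([0.05, 0.9, 0.05], y) < crossentropy_manual(yhat, y)\n",
    "```\n",
    "\n",
    "Minimizing this quantity pushes the probability assigned to the correct digit toward 1."
   ]
  },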
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use `X` and `Y` to create our training data (the full batch, repeated 200 times), then declare our evaluation callback and optimizer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
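    "dataset = repeated((X, Y), 200)    # the same full batch, 200 times\n",
    "evalcb = () -> @show(loss(X, Y))   # callback that prints the current loss\n",
    "opt = ADAM(Flux.params(m))         # ADAM optimizer over the model's parameters"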
   ]
  },
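  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `repeated` is unfamiliar, this small sketch shows what it produces. The symbols `:X` and `:Y` are placeholders for the real arrays, and `batches` is a name used only here.\n",
    "\n",
    "```julia\n",
    "using Base.Iterators: repeated\n",
    "\n",
    "# Collect the lazy iterator so we can inspect it\n",
    "batches = collect(repeated((:X, :Y), 3))\n",
    "\n",
    "# The same tuple appears once per repetition\n",
    "@assert length(batches) == 3\n",
    "@assert all(b -> b == (:X, :Y), batches)\n",
    "```\n",
    "\n",
    "With `repeated((X, Y), 200)`, the training loop therefore sees the same full batch 200 times, without copying the data."
   ]
  },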
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So far, we have defined our training data and our evaluation functions.\n",
    "\n",
    "Let's take a look at the function signature of `Flux.train!`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "?Flux.train!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Now we can train our model and look at the accuracy thereafter.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Train on every batch in `dataset`; `throttle` limits the callback to once every 10 seconds\n",
    "Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 10))\n",
    "\n",
    "accuracy(X, Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we've trained our model, let's create test data, `tX`,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "tX = hcat(float.(reshape.(MNIST.images(:test), :))...)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and run our model on one of the images from `tX`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test_image = m(tX[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Julia indices start at 1 but the digits start at 0, so subtract 1\n",
    "indmax(test_image) - 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The largest element of `test_image` (the model's output for the first test image) is the 8th element, so our model predicts that this image is a \"7\".\n",
    "\n",
    "Now we can look at the original image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "using Images\n",
    "t1 = reshape(tX[:,1],28,28)\n",
    "colorview(Gray, t1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And there we have it!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.6.0",
   "language": "julia",
   "name": "julia-0.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
